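# Counts ordered pairs of equal-valued grid cells that share a diagonal or anti-diagonal.
# (Author's note: this started as a brute-force solution aimed at partial credit.)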

def main():
    """Count equal-valued cell pairs on shared diagonals and print the result.

    Reads ``n m`` from the first stdin line, then n rows of integers (one row
    per line; only the first m values of each row are used). Prints the number
    of ORDERED pairs of distinct cells that hold the same value and lie on a
    common down-right diagonal (same ``i - j``) or down-left anti-diagonal
    (same ``i + j``).

    Produces the same output as walking from every cell along (1, 1) and
    (1, -1), but runs in O(n * m) instead of O(n * m * min(n, m)).
    """
    from collections import Counter

    n, m = map(int, input().split())
    # Bucket each cell by (line id, value): cells on one down-right diagonal
    # share i - j, cells on one down-left anti-diagonal share i + j.
    diag = Counter()
    anti = Counter()
    for i in range(n):
        row = list(map(int, input().split()))
        for j, v in enumerate(row[:m]):
            diag[i - j, v] += 1
            anti[i + j, v] += 1
    # A bucket of k equal cells yields k * (k - 1) ordered pairs. Two distinct
    # cells can never share both i - j and i + j, so no pair is double-counted.
    total = sum(k * (k - 1) for k in diag.values())
    total += sum(k * (k - 1) for k in anti.values())
    print(total)

# Run only when executed as a script, so importing this module does not block on stdin.
if __name__ == "__main__":
    main()